import torch
import numpy as np

from dada.model.torch_model import TorchModel


class LogSumExpModel(TorchModel):
    """Log-sum-exp objective ``f(x) = mu * log(sum_i exp((a_i^T x - b_i) / mu))``.

    ``a_matrix`` holds the vectors ``a_i`` as rows and ``b_vector`` the
    offsets ``b_i``; ``mu`` is the smoothing parameter that divides the
    margins before the log-sum-exp and multiplies the result afterwards.

    The same function is exposed twice: ``loss`` evaluates it with torch at
    the model's current point ``self.x``, and ``compute_value`` evaluates it
    with NumPy at an arbitrary point.
    """

    def __init__(self,
                 n_features: int,
                 num_polyhedron: int,
                 a_matrix: torch.Tensor,
                 b_vector: torch.Tensor,
                 mu: float,
                 init_point: torch.Tensor = None):
        """Store the problem data and initialise the base model.

        Args:
            n_features: Dimension of the variable, forwarded to ``TorchModel``.
            num_polyhedron: Number of rows of ``a_matrix``; only stored here.
            a_matrix: Matrix whose rows are the vectors ``a_i``.
            b_vector: Vector of offsets ``b_i``.
            mu: Smoothing parameter of the log-sum-exp.
            init_point: Optional starting point. It is cloned so the caller's
                tensor is not shared with the model. When ``None``, ``None``
                is forwarded to ``TorchModel`` (how the base class handles
                that is not visible here -- TODO confirm).
        """
        # The attributes are assigned before the base initialiser runs, as in
        # the original ordering (the base class may evaluate the loss while
        # constructing -- not visible from this file).
        self.num_polyhedron = num_polyhedron
        self.a_matrix = a_matrix
        self.b_vector = b_vector
        self.mu = mu

        # Guard the clone: the signature advertises ``None`` as the default,
        # and ``None.clone()`` would raise AttributeError.
        start = init_point.clone() if init_point is not None else None
        super().__init__(n_features, start)

    def loss(self):
        """Return the objective at ``self.x`` as a torch scalar.

        ``self.x`` is expected to be provided by ``TorchModel``.
        """
        # Scaled margins (a_i^T x - b_i) / mu, one entry per row of a_matrix.
        margins = torch.matmul(self.a_matrix, self.x) - self.b_vector
        scaled = margins / self.mu

        # torch.logsumexp is numerically stabilised internally.
        return self.mu * torch.logsumexp(scaled, dim=0)

    def compute_value(self, point: np.ndarray):
        """Return the objective at ``point`` using NumPy only.

        Args:
            point: Point at which to evaluate the function.

        Returns:
            ``mu * log(sum(exp((A @ point - b) / mu)))`` as a NumPy scalar.
        """
        # detach() drops autograd tracking and cpu() makes the conversion
        # valid for tensors that live on another device.
        a_matrix = self.a_matrix.detach().cpu().numpy()
        b_vector = self.b_vector.detach().cpu().numpy()

        scaled = (np.dot(a_matrix, point) - b_vector) / self.mu

        # Shift by the maximum before exponentiating so that exp() cannot
        # overflow for small mu / large margins:
        #   log(sum(exp(z))) = m + log(sum(exp(z - m))).
        shift = np.max(scaled, initial=-np.inf)
        if not np.isfinite(shift):
            # Empty or infinite input: shifting would produce inf - inf = nan,
            # so fall back to the unshifted formula, which yields +/-inf.
            shift = 0.0

        return self.mu * (shift + np.log(np.sum(np.exp(scaled - shift))))

    @staticmethod
    def generate_function_variables(n_features, num_polyhedron, mu):
        """Draw random problem data whose objective has zero gradient at 0.

        ``a_matrix`` and ``b_vector`` are sampled uniformly from ``[-1, 1]``.
        The gradient of the objective at ``x = 0`` is the softmax-weighted
        average of the rows, ``sum_i w_i a_i`` with ``w_i`` proportional to
        ``exp(-b_i / mu)``. Subtracting that vector from every row makes the
        gradient at the origin vanish, because the weights sum to one.

        Args:
            n_features: Number of columns of the generated matrix.
            num_polyhedron: Number of rows of the matrix / entries of ``b``.
            mu: Smoothing parameter used for the softmax weights.

        Returns:
            Tuple ``(a_matrix, b_vector)`` of double-precision torch tensors
            with shapes ``(num_polyhedron, n_features)`` and
            ``(num_polyhedron,)``.
        """
        # Sampling order (matrix first, then vector) is kept so that seeded
        # runs draw the same numbers as before.
        a_matrix = np.random.uniform(low=-1, high=1, size=(num_polyhedron, n_features))
        b_vector = np.random.uniform(low=-1, high=1, size=(num_polyhedron,))

        if num_polyhedron > 0:
            # Softmax of -b / mu, shifted by its maximum so the exponentials
            # stay in (0, 1] even for very small mu.
            exponents = -b_vector / mu
            weights = np.exp(exponents - exponents.max())
            weights /= weights.sum()

            # Gradient of the objective at the origin for the unshifted rows.
            grad_at_origin = weights @ a_matrix
            a_matrix -= grad_at_origin

        return (torch.from_numpy(a_matrix).double(),
                torch.from_numpy(b_vector).double())
